In [1]:
import numpy as np
import scipy
import scipy.stats
import torch as t

import matplotlib.pyplot as plt
from IPython.display import clear_output, display

from torch.nn import Sequential, Linear, ReLU, LeakyReLU, Dropout, Sigmoid
In [2]:
%matplotlib inline
In [3]:
device=t.device('cpu') # Override any earlier device choice: force everything onto the CPU

Generate the sample 2D distribution: points drawn uniformly from the unit circle.

In [4]:
# Sample 1000 angles uniformly from (-pi, pi) and map them onto the unit circle.
angle = np.random.uniform(-np.pi, np.pi, (1000, 1)).astype('float32')
data = np.hstack([np.cos(angle), np.sin(angle)])
plt.scatter(data[:, 0], data[:, 1])
Out[4]:
<matplotlib.collections.PathCollection at 0x7fe870eef950>

GAN implementation

In [5]:
# Dummy discriminator: substitute your own implementation.
# Maps a 2-D point to a single probability of the point being "real".
discriminator = Sequential(
    Linear(2, 50),    # 2-D input -> 50 hidden units
    LeakyReLU(0.2),
    Linear(50, 1),    # hidden -> single logit
    Sigmoid(),        # squash to (0, 1) for BCE
)
In [6]:
# Dummy generator: substitute your own implementation.
# A deep MLP mapping 2-D noise to a 2-D point; every Linear layer is
# followed by LeakyReLU(0.1). Layer widths are listed once and the
# (Linear, LeakyReLU) pairs are generated from consecutive widths.
_widths = (2, 2000, 1000, 500, 200, 100, 100, 50, 2)
generator = Sequential(
    *[
        layer
        for n_in, n_out in zip(_widths[:-1], _widths[1:])
        for layer in (Linear(n_in, n_out), LeakyReLU(0.1))
    ]
)
In [7]:
# Move both networks onto the chosen device before building optimizers,
# so the optimizers capture the on-device parameter tensors.
discriminator = discriminator.to(device) 
generator= generator.to(device)

# Separate Adam optimizers; the generator uses a smaller learning rate.
d_optimizer = t.optim.Adam(discriminator.parameters(), lr=0.001)
g_optimizer = t.optim.Adam(generator.parameters(), lr=0.0005)

# Binary cross-entropy: expects discriminator outputs already in (0, 1).
loss = t.nn.BCELoss()

The GAN training code starts here.

In [8]:
def real_data_target(size, device=None):
    '''
    Target labels for "real" samples: a (size, 1) tensor of ones.

    Parameters
    ----------
    size : int
        Batch size (number of label rows).
    device : torch.device, optional
        Device to create the labels on. Defaults to None (current default
        device), which matches the original CPU-only behaviour; pass the
        training device when running on GPU so labels match predictions.

    Returns
    -------
    torch.Tensor
        Shape (size, 1), filled with 1.0.
    '''
    return t.ones(size, 1, device=device)

def fake_data_target(size, device=None):
    '''
    Target labels for "fake" samples: a (size, 1) tensor of zeros.

    Parameters
    ----------
    size : int
        Batch size (number of label rows).
    device : torch.device, optional
        Device to create the labels on. Defaults to None (current default
        device), which matches the original CPU-only behaviour; pass the
        training device when running on GPU so labels match predictions.

    Returns
    -------
    torch.Tensor
        Shape (size, 1), filled with 0.0.
    '''
    return t.zeros(size, 1, device=device)
In [9]:
def train_discriminator(optimizer, real_data, fake_data):
    '''
    One discriminator update: push D(real) toward 1 and D(fake) toward 0.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer
        Optimizer over the discriminator's parameters.
    real_data : torch.Tensor, shape (batch, 2)
        A batch of samples from the target distribution.
    fake_data : torch.Tensor, shape (batch, 2)
        A batch of generator outputs.

    Uses the module-level `discriminator`, `loss`, `real_data_target`
    and `fake_data_target`.

    Returns
    -------
    torch.Tensor
        0-d tensor: summed BCE loss on the real and fake batches.
    '''
    # Reset gradients
    optimizer.zero_grad()

    # 1.1 Train on real data: target label is 1.
    prediction_real = discriminator(real_data)
    error_real = loss(prediction_real, real_data_target(real_data.size(0)))
    error_real.backward()

    # 1.2 Train on fake data: target label is 0. Detach so this backward
    # pass does not (uselessly) propagate gradients into the generator.
    prediction_fake = discriminator(fake_data.detach())
    # BUGFIX: size the fake targets from fake_data, not real_data, so the
    # function also works when the two batches differ in size.
    error_fake = loss(prediction_fake, fake_data_target(fake_data.size(0)))
    error_fake.backward()

    # 1.3 Update discriminator weights with the accumulated gradients.
    optimizer.step()

    # Return the combined error for logging.
    return error_real + error_fake
In [10]:
def train_generator(optimizer, fake_data):
    '''
    One generator update: push D(fake) toward 1, i.e. train the generator
    to fool the discriminator.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer
        Optimizer over the generator's parameters.
    fake_data : torch.Tensor, shape (batch, 2)
        Generator output that is still attached to the generator's graph
        (must NOT be detached, or no gradients reach the generator).

    Uses the module-level `discriminator`, `loss` and `real_data_target`.

    Returns
    -------
    torch.Tensor
        0-d tensor: the generator's BCE loss on this batch.
    '''
    optimizer.zero_grad()
    # Score the fake batch; the generator wants these scores near 1.
    d_scores = discriminator(fake_data)
    g_error = loss(d_scores, real_data_target(d_scores.size(0)))
    g_error.backward()
    optimizer.step()
    return g_error
In [11]:
def show(noise, fake, real):
    '''
    Scatter-plot the noise, generated, and real samples side by side.

    Each argument is a tensor of shape (batch, 2); only the first two
    columns are plotted. Tensors are moved to CPU before plotting.
    '''
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    panels = (
        (noise, "noise", "gray"),
        (fake, "fake", "red"),
        (real, "real", "green"),
    )
    for ax, (tensor, title, color) in zip(axes, panels):
        points = tensor.data.cpu().numpy()
        ax.set_title(title)
        ax.scatter(points[:, 0], points[:, 1], color=color)
    fig.tight_layout()
    plt.show()
In [12]:
num_epochs = 20000
sample_size = 500

for epoch in range(num_epochs):

    # Real data: a fresh uniform sample from the unit circle each epoch.
    angle = np.random.uniform(-np.pi, np.pi, (sample_size, 1)).astype('float32')
    data = np.concatenate((np.cos(angle), np.sin(angle)), axis=1)
    real_data = t.from_numpy(data)

    # Train the discriminator. BUGFIX: detach the generator output so the
    # discriminator's backward pass does not waste a full backward sweep
    # computing (unused) gradients through the large generator network.
    noise = t.empty(sample_size, 2, device=device).uniform_(-1, 1)
    fake_data = generator(noise).detach()
    d_error = train_discriminator(d_optimizer, real_data, fake_data)

    # Train the generator on a fresh noise batch. No detach here:
    # gradients must flow from the discriminator back into the generator.
    noise = t.empty(sample_size, 2, device=device).uniform_(-1, 1)
    fake_data = generator(noise)
    g_error = train_generator(g_optimizer, fake_data)

    # Periodically visualise progress and log both losses.
    if epoch % 100 == 0:
        show(noise, fake_data, real_data)
        print(f"Epoch: {epoch} G-error: {g_error.item()} D-error: {d_error.item()}")
Epoch: 0 G-error: 0.6800443530082703 D-error: 1.3687632083892822
Epoch: 100 G-error: 0.6589701771736145 D-error: 1.4028167724609375
Epoch: 200 G-error: 1.048340916633606 D-error: 1.3492401838302612
Epoch: 300 G-error: 0.7740062475204468 D-error: 1.3921338319778442
Epoch: 400 G-error: 0.7547128796577454 D-error: 1.3697733879089355
Epoch: 500 G-error: 0.7021133303642273 D-error: 1.4010388851165771
Epoch: 600 G-error: 0.7271661162376404 D-error: 1.3646022081375122
Epoch: 700 G-error: 0.7324806451797485 D-error: 1.3481385707855225
Epoch: 800 G-error: 0.7412444949150085 D-error: 1.3302277326583862
Epoch: 900 G-error: 0.7345081567764282 D-error: 1.3351585865020752
Epoch: 1000 G-error: 0.6736772060394287 D-error: 1.3868682384490967
Epoch: 1100 G-error: 0.7326403856277466 D-error: 1.3613433837890625
Epoch: 1200 G-error: 0.7209685444831848 D-error: 1.3642432689666748
Epoch: 1300 G-error: 0.7151890397071838 D-error: 1.3700217008590698
Epoch: 1400 G-error: 0.6372514963150024 D-error: 1.452711582183838
Epoch: 1500 G-error: 0.7240999937057495 D-error: 1.3660603761672974
Epoch: 1600 G-error: 0.6875603199005127 D-error: 1.3874924182891846
Epoch: 1700 G-error: 0.7180604934692383 D-error: 1.3417754173278809
Epoch: 1800 G-error: 0.7394790053367615 D-error: 1.3310604095458984
Epoch: 1900 G-error: 0.7392826676368713 D-error: 1.3246874809265137
Epoch: 2000 G-error: 0.7488161325454712 D-error: 1.3375381231307983
Epoch: 2100 G-error: 0.4169638454914093 D-error: 1.82957124710083
Epoch: 2200 G-error: 0.735609233379364 D-error: 1.3366308212280273
Epoch: 2300 G-error: 0.6843637228012085 D-error: 1.414646863937378
Epoch: 2400 G-error: 0.7691243290901184 D-error: 1.3180115222930908
Epoch: 2500 G-error: 0.675849437713623 D-error: 1.3814139366149902
Epoch: 2600 G-error: 0.6601279973983765 D-error: 1.4185190200805664
Epoch: 2700 G-error: 0.7507320046424866 D-error: 1.3392436504364014
Epoch: 2800 G-error: 0.6677773594856262 D-error: 1.4258625507354736
Epoch: 2900 G-error: 0.7376827001571655 D-error: 1.336186408996582
Epoch: 3000 G-error: 0.7444126605987549 D-error: 1.3581123352050781
Epoch: 3100 G-error: 0.6467046141624451 D-error: 1.4258816242218018
Epoch: 3200 G-error: 0.7420340180397034 D-error: 1.3120784759521484
Epoch: 3300 G-error: 0.649178683757782 D-error: 1.4638919830322266
Epoch: 3400 G-error: 0.7360811233520508 D-error: 1.3485981225967407
Epoch: 3500 G-error: 0.7326514720916748 D-error: 1.3362936973571777
Epoch: 3600 G-error: 0.751526415348053 D-error: 1.33856201171875
Epoch: 3700 G-error: 0.7326602935791016 D-error: 1.360972285270691
Epoch: 3800 G-error: 0.7385423183441162 D-error: 1.3264018297195435
Epoch: 3900 G-error: 0.7342178821563721 D-error: 1.383987307548523
Epoch: 4000 G-error: 0.6941207051277161 D-error: 1.3944238424301147
Epoch: 4100 G-error: 0.7012629508972168 D-error: 1.3843698501586914
Epoch: 4200 G-error: 0.670846700668335 D-error: 1.421783685684204
Epoch: 4300 G-error: 0.6844713687896729 D-error: 1.4204260110855103
Epoch: 4400 G-error: 0.6663187146186829 D-error: 1.411177635192871
Epoch: 4500 G-error: 0.6838784217834473 D-error: 1.4044921398162842
Epoch: 4600 G-error: 0.5590541958808899 D-error: 1.556544303894043
Epoch: 4700 G-error: 0.7181574106216431 D-error: 1.3384222984313965
Epoch: 4800 G-error: 0.5954300761222839 D-error: 1.4596683979034424
Epoch: 4900 G-error: 0.5647479295730591 D-error: 1.5075385570526123
Epoch: 5000 G-error: 0.7464523911476135 D-error: 1.4135409593582153
Epoch: 5100 G-error: 0.6499063968658447 D-error: 1.434289813041687
Epoch: 5200 G-error: 0.6963269710540771 D-error: 1.3514609336853027
Epoch: 5300 G-error: 0.5399724841117859 D-error: 1.5355144739151
Epoch: 5400 G-error: 0.7787526845932007 D-error: 1.2767560482025146
Epoch: 5500 G-error: 0.667171061038971 D-error: 1.386183261871338
Epoch: 5600 G-error: 0.7896753549575806 D-error: 1.2795605659484863
Epoch: 5700 G-error: 0.746955156326294 D-error: 1.3611514568328857
Epoch: 5800 G-error: 0.7185946106910706 D-error: 1.3872957229614258
Epoch: 5900 G-error: 0.6583539247512817 D-error: 1.4140799045562744
Epoch: 6000 G-error: 0.7775381207466125 D-error: 1.321088433265686
Epoch: 6100 G-error: 0.7395948171615601 D-error: 1.3780407905578613
Epoch: 6200 G-error: 0.718389630317688 D-error: 1.3854844570159912
Epoch: 6300 G-error: 0.7245966196060181 D-error: 1.3570529222488403
Epoch: 6400 G-error: 0.6487470865249634 D-error: 1.4664068222045898
Epoch: 6500 G-error: 0.6996126174926758 D-error: 1.3884212970733643
Epoch: 6600 G-error: 0.7164443731307983 D-error: 1.3683061599731445
Epoch: 6700 G-error: 0.7423868179321289 D-error: 1.2930257320404053
Epoch: 6800 G-error: 0.6886661052703857 D-error: 1.4513261318206787
Epoch: 6900 G-error: 0.6121377348899841 D-error: 1.4645928144454956
Epoch: 7000 G-error: 0.7688461542129517 D-error: 1.335508108139038
Epoch: 7100 G-error: 0.7865986824035645 D-error: 1.2780680656433105
Epoch: 7200 G-error: 0.8093165755271912 D-error: 1.2844746112823486
Epoch: 7300 G-error: 0.7874767184257507 D-error: 1.3302700519561768
Epoch: 7400 G-error: 0.7532477974891663 D-error: 1.3300714492797852
Epoch: 7500 G-error: 0.7749287486076355 D-error: 1.2787209749221802
Epoch: 7600 G-error: 0.7246143817901611 D-error: 1.2951247692108154
Epoch: 7700 G-error: 0.7798743844032288 D-error: 1.2761452198028564
Epoch: 7800 G-error: 0.7947475910186768 D-error: 1.2763099670410156
Epoch: 7900 G-error: 0.7237026691436768 D-error: 1.3649399280548096
Epoch: 8000 G-error: 0.7178916335105896 D-error: 1.3708447217941284
Epoch: 8100 G-error: 0.48974016308784485 D-error: 1.614924669265747
Epoch: 8200 G-error: 0.6930974125862122 D-error: 1.316532850265503
Epoch: 8300 G-error: 0.8314673900604248 D-error: 1.2151647806167603
Epoch: 8400 G-error: 0.8028542995452881 D-error: 1.3136051893234253
Epoch: 8500 G-error: 0.8043620586395264 D-error: 1.2830970287322998
Epoch: 8600 G-error: 0.6446516513824463 D-error: 1.4360864162445068
Epoch: 8700 G-error: 0.7298987507820129 D-error: 1.4023547172546387
Epoch: 8800 G-error: 0.7743549942970276 D-error: 1.307661533355713
Epoch: 8900 G-error: 0.832358717918396 D-error: 1.2695221900939941
Epoch: 9000 G-error: 0.7618780136108398 D-error: 1.2711315155029297
Epoch: 9100 G-error: 0.8279746174812317 D-error: 1.2485607862472534
Epoch: 9200 G-error: 0.8576734066009521 D-error: 1.2466458082199097
Epoch: 9300 G-error: 0.5892800092697144 D-error: 1.4986191987991333
Epoch: 9400 G-error: 0.726184070110321 D-error: 1.3784558773040771
Epoch: 9500 G-error: 0.4893801808357239 D-error: 1.636099934577942
Epoch: 9600 G-error: 0.6366032361984253 D-error: 1.3925013542175293
Epoch: 9700 G-error: 0.8547666668891907 D-error: 1.2207987308502197
Epoch: 9800 G-error: 0.8240447044372559 D-error: 1.2517725229263306
Epoch: 9900 G-error: 0.7849145531654358 D-error: 1.3657300472259521
Epoch: 10000 G-error: 0.6545589566230774 D-error: 1.392699122428894
Epoch: 10100 G-error: 0.753450334072113 D-error: 1.3954832553863525
Epoch: 10200 G-error: 0.7478339672088623 D-error: 1.3215458393096924
Epoch: 10300 G-error: 0.8579840660095215 D-error: 1.2455781698226929
Epoch: 10400 G-error: 0.7582443952560425 D-error: 1.364048719406128
Epoch: 10500 G-error: 0.7884209156036377 D-error: 1.3171347379684448
Epoch: 10600 G-error: 0.5440790057182312 D-error: 1.511678695678711
Epoch: 10700 G-error: 0.8052987456321716 D-error: 1.3198603391647339
Epoch: 10800 G-error: 0.47146373987197876 D-error: 1.6872408390045166
Epoch: 10900 G-error: 0.8111518621444702 D-error: 1.2205688953399658
Epoch: 11000 G-error: 0.8577250838279724 D-error: 1.2706319093704224
Epoch: 11100 G-error: 0.7910346388816833 D-error: 1.2938024997711182
Epoch: 11200 G-error: 0.5078915357589722 D-error: 1.606333613395691
Epoch: 11300 G-error: 0.8614342212677002 D-error: 1.2128467559814453
Epoch: 11400 G-error: 0.7473271489143372 D-error: 1.3658878803253174
Epoch: 11500 G-error: 0.6256973147392273 D-error: 1.4762961864471436
Epoch: 11600 G-error: 0.7898271679878235 D-error: 1.334782361984253
Epoch: 11700 G-error: 0.6625677347183228 D-error: 1.3901045322418213
Epoch: 11800 G-error: 0.6307127475738525 D-error: 1.4879541397094727
Epoch: 11900 G-error: 0.7958603501319885 D-error: 1.2344019412994385
Epoch: 12000 G-error: 0.8463006615638733 D-error: 1.2973637580871582
Epoch: 12100 G-error: 0.6726781129837036 D-error: 1.3982057571411133
Epoch: 12200 G-error: 0.6061651110649109 D-error: 1.498374581336975
Epoch: 12300 G-error: 0.4993939697742462 D-error: 1.6161072254180908
Epoch: 12400 G-error: 0.7151942253112793 D-error: 1.380861759185791
Epoch: 12500 G-error: 0.7709327340126038 D-error: 1.286949634552002
Epoch: 12600 G-error: 0.8448091149330139 D-error: 1.2270656824111938
Epoch: 12700 G-error: 0.70270836353302 D-error: 1.3482048511505127
Epoch: 12800 G-error: 0.6127437949180603 D-error: 1.5234178304672241
Epoch: 12900 G-error: 0.827381432056427 D-error: 1.1901936531066895
Epoch: 13000 G-error: 0.8465868234634399 D-error: 1.3025401830673218
Epoch: 13100 G-error: 0.6320017576217651 D-error: 1.4729304313659668
Epoch: 13200 G-error: 0.698807954788208 D-error: 1.3574392795562744
Epoch: 13300 G-error: 0.8613673448562622 D-error: 1.2320353984832764
Epoch: 13400 G-error: 0.8397867679595947 D-error: 1.2847434282302856
Epoch: 13500 G-error: 0.6293087601661682 D-error: 1.4465664625167847
Epoch: 13600 G-error: 0.6744344234466553 D-error: 1.4253040552139282
Epoch: 13700 G-error: 0.8426110744476318 D-error: 1.21433424949646
Epoch: 13800 G-error: 0.8787907958030701 D-error: 1.2313505411148071
Epoch: 13900 G-error: 0.7544004321098328 D-error: 1.3990637063980103
Epoch: 14000 G-error: 0.6484606862068176 D-error: 1.359731912612915
Epoch: 14100 G-error: 0.8698192238807678 D-error: 1.2464543581008911
Epoch: 14200 G-error: 0.7279305458068848 D-error: 1.3861443996429443
Epoch: 14300 G-error: 0.8611162304878235 D-error: 1.196054220199585
Epoch: 14400 G-error: 0.668601930141449 D-error: 1.4025166034698486
Epoch: 14500 G-error: 0.5335410237312317 D-error: 1.579298734664917
Epoch: 14600 G-error: 0.8189011216163635 D-error: 1.3000364303588867
Epoch: 14700 G-error: 0.6250894665718079 D-error: 1.5215225219726562
Epoch: 14800 G-error: 0.8506127595901489 D-error: 1.1619644165039062
Epoch: 14900 G-error: 0.8824940323829651 D-error: 1.232287883758545
Epoch: 15000 G-error: 0.8017228841781616 D-error: 1.3043038845062256
Epoch: 15100 G-error: 0.9232498407363892 D-error: 1.183302402496338
Epoch: 15200 G-error: 0.5475579500198364 D-error: 1.5669679641723633
Epoch: 15300 G-error: 0.7445740103721619 D-error: 1.25714910030365
Epoch: 15400 G-error: 0.9207112193107605 D-error: 1.150181770324707
Epoch: 15500 G-error: 0.7583397626876831 D-error: 1.4313806295394897
Epoch: 15600 G-error: 0.705367922782898 D-error: 1.404977560043335
Epoch: 15700 G-error: 0.684800922870636 D-error: 1.3692673444747925
Epoch: 15800 G-error: 0.9516168236732483 D-error: 1.1945631504058838
Epoch: 15900 G-error: 0.7967294454574585 D-error: 1.281465768814087
Epoch: 16000 G-error: 0.7406124472618103 D-error: 1.3460381031036377
Epoch: 16100 G-error: 0.8581329584121704 D-error: 1.1484026908874512
Epoch: 16200 G-error: 0.8888218402862549 D-error: 1.3010917901992798
Epoch: 16300 G-error: 0.5109542012214661 D-error: 1.6048986911773682
Epoch: 16400 G-error: 0.9539903998374939 D-error: 1.1611313819885254
Epoch: 16500 G-error: 0.877132773399353 D-error: 1.2333979606628418
Epoch: 16600 G-error: 0.6564957499504089 D-error: 1.4726498126983643
Epoch: 16700 G-error: 0.734095573425293 D-error: 1.3695894479751587
Epoch: 16800 G-error: 0.6721248626708984 D-error: 1.3914730548858643
Epoch: 16900 G-error: 0.8282356858253479 D-error: 1.2323448657989502
Epoch: 17000 G-error: 0.782503068447113 D-error: 1.379913330078125
Epoch: 17100 G-error: 0.5938522815704346 D-error: 1.4651542901992798
Epoch: 17200 G-error: 0.6008198857307434 D-error: 1.4745818376541138
Epoch: 17300 G-error: 0.5533403754234314 D-error: 1.5077459812164307
Epoch: 17400 G-error: 0.802611768245697 D-error: 1.3664743900299072
Epoch: 17500 G-error: 0.5570070743560791 D-error: 1.5475350618362427
Epoch: 17600 G-error: 0.8245347738265991 D-error: 1.2461507320404053
Epoch: 17700 G-error: 0.15060880780220032 D-error: 2.6097216606140137
Epoch: 17800 G-error: 0.7422105073928833 D-error: 1.350264549255371
Epoch: 17900 G-error: 0.5384538769721985 D-error: 1.536644697189331
Epoch: 18000 G-error: 0.9339253902435303 D-error: 1.158017635345459
Epoch: 18100 G-error: 0.713828444480896 D-error: 1.4121134281158447
Epoch: 18200 G-error: 0.928147554397583 D-error: 1.1372759342193604
Epoch: 18300 G-error: 0.6832849383354187 D-error: 1.476395845413208
Epoch: 18400 G-error: 0.6974059343338013 D-error: 1.34727144241333
Epoch: 18500 G-error: 0.8570052981376648 D-error: 1.1862518787384033
Epoch: 18600 G-error: 0.8899285197257996 D-error: 1.2711997032165527
Epoch: 18700 G-error: 0.3678674101829529 D-error: 1.7925114631652832
Epoch: 18800 G-error: 0.9243385195732117 D-error: 1.1952717304229736
Epoch: 18900 G-error: 0.455410361289978 D-error: 1.7024316787719727
Epoch: 19000 G-error: 0.8931671380996704 D-error: 1.1625603437423706
Epoch: 19100 G-error: 0.5997143983840942 D-error: 1.5655604600906372
Epoch: 19200 G-error: 0.9446706175804138 D-error: 1.063122034072876
Epoch: 19300 G-error: 0.9037319421768188 D-error: 1.251272201538086
Epoch: 19400 G-error: 0.6224573254585266 D-error: 1.5075581073760986
Epoch: 19500 G-error: 0.9315957427024841 D-error: 1.1297528743743896
Epoch: 19600 G-error: 0.5498090386390686 D-error: 1.57270085811615
Epoch: 19700 G-error: 0.9656845927238464 D-error: 1.098314642906189
Epoch: 19800 G-error: 0.6372727155685425 D-error: 1.4881141185760498
Epoch: 19900 G-error: 0.9195835590362549 D-error: 1.1867170333862305

# NOTE(review): dead leftover fragment — not a valid cell. It references
# `real_batch` and `torch` (this notebook imports `torch as t`), neither of
# which is defined here, and its indentation suggests it was pasted out of a
# larger loop body. It appears to be reference/template code for the training
# loop above and should be deleted; kept verbatim below for review.
batch_size = real_batch.shape[0]

    # 1. Train Discriminator
    real_data = real_batch
    # Generate fake data
    noise = torch.randn(batch_size, 100)
    fake_data = generator(noise).detach()
    # Train D
    d_error = train_discriminator(d_optimizer, real_data, fake_data)

    # 2. Train Generator
    # Generate fake data
    noise = torch.randn(batch_size, 100)
    fake_data = generator(noise)
    # Train G
    g_error = train_generator(g_optimizer, fake_data)
    # Log error

Final result

In [13]:
# Final check: compare a fresh real sample with the trained generator's output.
angle = np.random.uniform(-np.pi, np.pi, (sample_size, 1)).astype('float32')
data = np.concatenate((np.cos(angle), np.sin(angle)), axis=1)
real_data = t.from_numpy(data)

# Fake data from the trained generator; detach since no gradients are needed.
noise = t.empty(sample_size, 2, device=device).uniform_(-1, 1)
fake_data = generator(noise).detach()

show(noise, fake_data, real_data)
In [ ]:
 

Problem 1

Implement the GAN training loop that will train the GAN to generate samples from the sample distribution.

Problem 2

Use another sampling distribution — one that is not concentrated on the unit circle, e.g. an ellipse.

In [ ]: